{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import cv2\n",
    "from scipy.spatial import cKDTree\n",
    "\n",
    "from hfnet.evaluation.utils.descriptors import matches_cv2np, normalize as norm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2D-3D Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_2d = 2000\n",
    "n_3d = 3000\n",
    "\n",
    "# Fixed seed so the synthetic descriptors (and therefore the matches compared\n",
    "# across the backends below) are reproducible after a kernel restart.\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# Random 256-D query descriptors, centred around zero and passed through `norm`\n",
    "# (presumably L2 normalisation: the 2*(1 - dot) distance used below relies on it).\n",
    "d_2d = norm(rng.rand(n_2d, 256) - .5)\n",
    "# Each 3D descriptor is a noisy copy of a randomly chosen 2D descriptor;\n",
    "# l_3d records the index of the 2D descriptor it was derived from.\n",
    "l_3d = rng.randint(0, high=n_2d, size=n_3d)\n",
    "d_3d = norm(d_2d[l_3d] + rng.normal(scale=0.3, size=(n_3d, 256)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OpenCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.31 s, sys: 25.1 ms, total: 1.34 s\n",
      "Wall time: 180 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Brute-force L2 matcher; descriptors are cast to float32 for OpenCV.\n",
    "bf_matcher = cv2.BFMatcher(cv2.NORM_L2)\n",
    "knn_matches = bf_matcher.knnMatch(\n",
    "    d_2d.astype(np.float32), d_3d.astype(np.float32), k=2)\n",
    "# Split the per-query (nearest, second nearest) pairs into two sequences.\n",
    "nearest, second_nearest = zip(*knn_matches)\n",
    "matches1, dist1 = matches_cv2np(nearest)\n",
    "matches2, dist2 = matches_cv2np(second_nearest)\n",
    "# Keep a match when both neighbours share the same label, or when the nearest\n",
    "# one passes the ratio test on the distances.\n",
    "# NOTE(review): column 1 of the converted matches is assumed to be the index\n",
    "# into d_3d -- confirm against matches_cv2np.\n",
    "same_label = l_3d[matches1[:, 1]] == l_3d[matches2[:, 1]]\n",
    "passes_ratio = dist1/dist2 < 0.95\n",
    "good = same_label | passes_ratio\n",
    "matches = matches1[good]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 655 ms, sys: 1.64 s, total: 2.29 s\n",
      "Wall time: 378 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Pairwise distance between every 2D and 3D descriptor. For unit-norm rows\n",
    "# (assumed from `norm`) this equals the squared Euclidean distance.\n",
    "dist = 2*(1 - d_2d @ d_3d.T)\n",
    "# argpartition with kth=2 only guarantees that the two smallest entries of each\n",
    "# row come first, NOT that they are ordered. Sort the selected pair explicitly\n",
    "# so that column 0 is always the nearest neighbour; otherwise the ratio test and\n",
    "# the reported match index can silently use the second nearest one.\n",
    "ind = np.argpartition(dist, 2, axis=-1)[:, :2]\n",
    "dist_nn = np.take_along_axis(dist, ind, axis=-1)\n",
    "order = np.argsort(dist_nn, axis=-1)\n",
    "ind = np.take_along_axis(ind, order, axis=-1)\n",
    "dist_nn = np.take_along_axis(dist_nn, order, axis=-1)\n",
    "labels_nn = l_3d[ind]\n",
    "\n",
    "# The ratio threshold is squared because `dist` holds squared distances.\n",
    "thresh = 0.95**2\n",
    "# Accept when both neighbours share a label, or when the ratio test passes.\n",
    "match_ok = (labels_nn[:, 0] == labels_nn[:, 1])\n",
    "match_ok |= (dist_nn[:, 0] <= thresh*dist_nn[:, 1])\n",
    "# Row 0: indices into d_2d; row 1: index of the nearest d_3d descriptor.\n",
    "matches = np.stack([np.where(match_ok)[0], ind[match_ok][:, 0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tensor versions of the descriptors and labels for the PyTorch matchers below.\n",
    "td_2d, td_3d, tl_3d = (torch.from_numpy(arr) for arr in (d_2d, d_3d, l_3d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 203 ms, sys: 73.7 ms, total: 276 ms\n",
      "Wall time: 178 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "with torch.no_grad():\n",
    "    # Out-of-place transpose: td_3d is defined in another cell, and an in-place\n",
    "    # t_() would flip its shape on every execution, so a second run of this cell\n",
    "    # (or any later use of td_3d) would fail or see transposed data.\n",
    "    dist = 2*(1 - td_2d @ td_3d.t())\n",
    "\n",
    "    # Two smallest distances per row and their indices into td_3d.\n",
    "    dist_nn, ind = dist.topk(2, dim=-1, largest=False)\n",
    "    labels_nn = tl_3d[ind]\n",
    "\n",
    "    # The ratio threshold is squared because `dist` holds squared distances\n",
    "    # (assuming unit-norm descriptors).\n",
    "    thresh = 0.95**2\n",
    "    # Accept when both neighbours share a label, or when the ratio test passes.\n",
    "    match_ok = (labels_nn[:, 0] == labels_nn[:, 1])\n",
    "    match_ok |= (dist_nn[:, 0] <= thresh*dist_nn[:, 1])\n",
    "    # Row 0: indices into td_2d; row 1: index of the nearest td_3d descriptor.\n",
    "    matches = torch.stack([torch.nonzero(match_ok)[:, 0], ind[match_ok][:, 0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.jit.script\n",
    "def jit_matching(desc1, desc2, ratio_thresh, labels):\n",
    "    \"\"\"Ratio-test matching of desc1 against desc2 with a label override.\n",
    "\n",
    "    For each row of desc1 the two closest rows of desc2 are found using the\n",
    "    distance 2*(1 - desc1 @ desc2.t()). A row is matched to its closest\n",
    "    neighbour when that distance is at most ratio_thresh**2 times the second\n",
    "    closest distance, or when both neighbours have the same entry in labels.\n",
    "\n",
    "    Args:\n",
    "        desc1: 2-D tensor of query descriptors, one per row.\n",
    "        desc2: 2-D tensor of reference descriptors, one per row.\n",
    "        ratio_thresh: ratio threshold; it is squared before use. Left\n",
    "            unannotated, so TorchScript treats it as a Tensor.\n",
    "        labels: tensor indexed by the row indices of desc2.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with one row per accepted match: column 0 is the row index in\n",
    "        desc1, column 1 the row index of its closest neighbour in desc2.\n",
    "    \"\"\"\n",
    "    pairwise = 2*(1 - desc1 @ desc2.t())\n",
    "    nn_dist, nn_idx = pairwise.topk(2, dim=-1, largest=False)\n",
    "    nn_labels = labels[nn_idx]\n",
    "    passes_ratio = nn_dist[:, 0] <= (ratio_thresh**2)*nn_dist[:, 1]\n",
    "    same_label = nn_labels[:, 0] == nn_labels[:, 1]\n",
    "    keep = passes_ratio | same_label\n",
    "    query_idx = torch.nonzero(keep)[:, 0]\n",
    "    nearest_idx = nn_idx[keep][:, 0]\n",
    "    return torch.stack([query_idx, nearest_idx], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the scripted matcher to a file in the current working directory.\n",
    "output_path = 'pytorch_matching.pt'\n",
    "jit_matching.save(output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 20000  # number of database descriptors\n",
    "k = 10     # number of nearest neighbours to retrieve\n",
    "# Fixed seed so the query/database pair is reproducible after a kernel restart.\n",
    "rng = np.random.RandomState(0)\n",
    "# One 1024-D query and n database descriptors, centred around zero and passed\n",
    "# through `norm` (presumably L2 normalisation -- the distance below relies on it).\n",
    "query = norm(rng.rand(1024) - .5)\n",
    "db = norm(rng.rand(n, 1024) - .5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## kd-tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the k-d tree over the database descriptors once; the query cell reuses it.\n",
    "index = cKDTree(data=db)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 50.7 ms, sys: 21.4 ms, total: 72.1 ms\n",
      "Wall time: 71.2 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Distances to, and database indices of, the k nearest neighbours of the query.\n",
    "neighbours = index.query(query, k=k)\n",
    "d, ind = neighbours"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 187 ms, sys: 576 ms, total: 763 ms\n",
      "Wall time: 97.2 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "dist = 2 * (1 - db @ query)\n",
    "# Pick the k smallest distances without a full sort, then order only those k.\n",
    "candidates = np.argpartition(dist, k)[:k]\n",
    "by_distance = np.argsort(dist[candidates])\n",
    "ind = candidates[by_distance]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 215,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tensor versions of the query and database descriptors for the PyTorch variant.\n",
    "tquery, tdb = torch.from_numpy(query), torch.from_numpy(db)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 54.3 ms, sys: 0 ns, total: 54.3 ms\n",
      "Wall time: 10.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "with torch.no_grad():\n",
    "    similarity = tdb @ tquery\n",
    "    dist = 2*(1 - similarity)\n",
    "    # Indices of the k smallest distances.\n",
    "    _, ind = torch.topk(dist, k, largest=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
